module project_routines

# ---------------------------------------------------------
# using Parameters, FileIO, CSV, LaTeXStrings, Plots, PGFPlotsX, Distributed
using Random: MersenneTwister, randn
using Statistics: mean, std, cor, I, kron, UpperTriangular, median
using DataFrames
using Parameters
using LaTeXStrings
using Plots
# ---------------------------------------------------------


# ---------------------------------------------------------
export parameters, parameters_TradeWar,
    utilityvars_autarky, utilityvars_complete, utilityvars_incomplete
export random_productivity_state
export solver_CompleteMarkets
export solver_CE
export solver_incomplete

export TradeWar_USChina

export bilateral_trade_deal_cf
export bilateral_trade_deal_analyze

export no_diag_mean
export myplotter_draft

export triangle_inequality
# ---------------------------------------------------------


# ---------------------------------------------------------
## Shocks and States
"""
    shocks(zj, T)

Container for a simulated productivity environment.

# Fields
- `zj::Matrix{Float64}`: productivity paths, one row per country and one column per
  period (the solvers read period `s` as `zj[:, s]`).
- `T::Int64`: number of simulated periods.
"""
struct shocks
    zj::Matrix{Float64}
    T::Int64
end

"""
    parameters(eta, gamma, N, dt, Lj, lambdaj, Tau, stepsize, tol)

Structure to contain parameters for a main simulation.

# Arguments
- `eta::Float64`: Elasticity of substitution.
- `gamma::Float64`: Relative risk aversion.
- `N::Int64`: Number of countries.
- `dt::Float64`: Time step as a fraction of a year; e.g., 1/12 for monthly.
- `Lj::Vector{Float64}`: Vector of size N, relative country sizes.
- `lambdaj::Array{Float64}`: Pareto weights. A length-N vector for the complete-market
  solver; `solver_incomplete` reads `lambdaj[:, s]`, i.e. an N×T matrix with one column
  per period. May be empty when the solver does not read it (e.g. `solver_CE`).
- `Tau::Array{Float64,2}`: N×N matrix of trade costs.
- `stepsize::Float64`: Relative step size for the price iteration.
- `tol::Float64`: Tolerance on the maximum absolute excess demand; the price iteration
  stops once `maximum(abs.(H)) <= tol`.
"""
struct parameters
    eta::Float64
    gamma::Float64
    N::Int64
    dt::Float64
    Lj::Vector{Float64}
    lambdaj::Array{Float64}
    Tau::Array{Float64,2}
    stepsize::Float64
    tol::Float64
end

"""
    parameters_TradeWar(eta, gamma, N, dt, phi, Lj, lambdaj, Tau, stepsize, tol)

Parameter set for the trade-war and bilateral-trade-deal counterfactuals: the fields of
`parameters` plus the blending weight `phi`.

# Fields
- `eta`: elasticity of substitution.
- `gamma`: relative risk aversion.
- `N`: number of countries.
- `dt`: time step as a fraction of a year.
- `phi`: weight on `lambdaj` when the counterfactual Pareto weights are formed as
  `lambdaj_ce_scld * (1 - phi) .+ lambdaj * phi`.
- `Lj`: relative country sizes (length N).
- `lambdaj`: Pareto weights.
- `Tau`: N×N trade-cost matrix.
- `stepsize`: relative step size of the price iteration.
- `tol`: convergence tolerance of the price iteration.
"""
struct parameters_TradeWar
    eta::Float64
    gamma::Float64
    N::Int64
    dt::Float64
    phi::Float64
    Lj::Vector{Float64}
    lambdaj::Array{Float64}
    Tau::Matrix{Float64}
    stepsize::Float64
    tol::Float64
end


# --- 
## Autarky Solvers
"""
    utilityvars_autarky(; N, T, dt, pjt, H, R, pj)

Mutable work space for the autarky price iteration (`solver_CE`); its fields are
overwritten while the excess-demand equations are solved.

# Fields
- `N`, `T`, `dt`: number of countries, number of periods and time step (required).
- `pjt`: current price guess, length N (default ones).
- `H`: excess demand at `pjt`, length N (default ones).
- `R`: N×N scratch matrix overwritten by `H_CE_excess_demand` (default ones).
- `pj`: N×T matrix of solved prices, one column per period (default ones).
"""
@with_kw mutable struct utilityvars_autarky
    N::Int64
    T::Int64
    dt::Float64
    pjt::Vector{Float64} = ones(Float64, N)
    H::Vector{Float64} = ones(Float64, N)
    R::Matrix{Float64} = ones(Float64, N, N)
    pj::Matrix{Float64} = ones(Float64, N, T)
end
# Example: uvars_autarky = utilityvars_autarky(N=10, T=50, dt=1/12)  (dt has no default)

# ---
## Complete Market Solvers
"""
    utilityvars_complete(; N, T, dt, pjt, H, Pj, pj)

Mutable work space for the complete-market price iteration (`solver_CompleteMarkets`);
the solver and `H_PP_excess_demand` overwrite its fields.

# Fields
- `N`, `T`, `dt`: number of countries, number of periods and time step (required).
- `pjt`: N×1 current price guess (default ones).
- `H`: N×1 excess demand at `pjt` (default ones).
- `Pj`: N×1 price indices written by `H_PP_excess_demand` (default ones).
- `pj`: N×T solved prices, one column per period (default ones).
"""
@with_kw mutable struct utilityvars_complete
    N::Int64
    T::Int64
    dt::Float64
    pjt::Matrix{Float64} = ones(Float64, N, 1)
    H::Matrix{Float64} = ones(Float64, N, 1)
    Pj::Matrix{Float64} = ones(Float64, N, 1)
    pj::Matrix{Float64} = ones(Float64, N, T)
end
# Example: uvars_complete = utilityvars_complete(N=10, T=50, dt=1/12)  (dt has no default)

# ---
## InComplete Market Solvers
"""
    utilityvars_incomplete(; N, T, dt, pjt, H, Pj, pj)

Mutable work space for the incomplete-market price iteration (`solver_incomplete`);
the solver and `H_Inc_excess_demand` overwrite its fields.

# Fields
- `N`, `T`, `dt`: number of countries, number of periods and time step (required).
- `pjt`: N×1 current price guess (default ones).
- `H`: N×1 excess demand at `pjt` (default ones).
- `Pj`: N×1 price indices written by `H_Inc_excess_demand` (default ones).
- `pj`: N×T solved prices, one column per period (default ones).
"""
@with_kw mutable struct utilityvars_incomplete
    N::Int64
    T::Int64
    dt::Float64
    pjt::Matrix{Float64} = ones(Float64, N, 1)
    H::Matrix{Float64} = ones(Float64, N, 1)
    Pj::Matrix{Float64} = ones(Float64, N, 1)
    pj::Matrix{Float64} = ones(Float64, N, T)
end
# Example: uvars_incomplete = utilityvars_incomplete(N=10, T=50, dt=1/12)  (dt has no default)
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    solver_CompleteMarkets(calibration, states, uvars)

Solve the complete market economy for a calibration and state vector.

For each period `s = 1:T` the tradable prices are updated as
`pjt = pjt .* (1 .+ stepsize * H)` until `maximum(abs.(H)) <= tol`, where `H` is the
excess demand from `H_PP_excess_demand`. Every period after the first starts from the
previous period's solution, and the solved prices are stored in `uvars.pj`.

# Arguments
- `calibration`: needs fields `eta, gamma, N, Lj, lambdaj, Tau, stepsize, tol`
  (e.g. a `parameters`); here `lambdaj` holds one Pareto weight per country.
- `states`: needs fields `zj` (N×T productivity) and `T` (e.g. a `shocks`).
- `uvars`: a `utilityvars_complete`; its fields are overwritten. It may be reused across
  calls: the convergence check is reset for every period, including the first.

# Returns
A 5-element `Vector{Float64}`, rounded to 3 digits,
`[vol_dcj, vol_dsij, mean_cor_dcj, mean_cor_dyj, mean_BS_coeff]`:
- the mean standard deviation of annual log-consumption growth;
- the mean volatility of per-period log changes of `Sij = Pj[i] / Pj[j]` (after
  `remove_diagonal`), scaled by `sqrt(1 / dt)`;
- the mean correlations of annual consumption and output growth across countries `1:43`
  (via `no_diag_mean`);
- the mean off-diagonal slope of regressions of `dsij_ann[j, i, :]` on
  `dcj[i, :] - dcj[j, :]`.

NOTE(review): country 44 is used as the price numeraire and `1:43` is hard-coded, so the
function requires `N >= 44` and is presumably written for the 44-country sample.

# Examples
```jldoctest
julia> Complete_output = solver_CompleteMarkets(calibrationComplete, states, uvars_complete)
5-element Vector{Float64}:
 0.022
 0.119
 0.887
 0.356
 5.972
```
"""
function solver_CompleteMarkets(calibration, states, uvars)
    @unpack zj, T = states
    @unpack eta, gamma, N, Lj, lambdaj, Tau, stepsize, tol = calibration

    # --- Equilibrium tradable prices, one period at a time ---------------------------
    for s = 1:T
        # Reset H so the `while` loop runs at least once. Without this reset, a `uvars`
        # reused from an earlier call still holds a converged H, period 1 is skipped and
        # the previous run's last prices end up in column 1.
        uvars.H = ones(N, 1)
        if s > 1
            uvars.pjt .= @view(uvars.pj[:, s-1])  # warm start from previous period
        end
        while maximum(abs.(uvars.H)) > tol
            uvars.H =
                H_PP_excess_demand(uvars.pjt, @view(zj[:, s]), calibration, uvars)
            uvars.pjt = uvars.pjt .* (1 .+ stepsize * uvars.H)  # damped update
        end
        uvars.pj[:, s] = uvars.pjt
    end

    # --- Normalized prices, price indices and exchange rates -------------------------
    # Country 44's price times productivity is the numeraire (requires N >= 44).
    pj = uvars.pj ./ repeat(uvars.pj[44, :]' .* zj[44, :]', N, 1)
    wj = zj .* pj
    Pj = ones(N, T)
    Sij = ones(N, N, T)
    for s = 1:T
        Pj[:, s] =
            sum((Tau .* pj[:, s]') .^ (1 - eta), dims = 2) .^ (1 / (1 - eta))
        Sij[:, :, s] = Pj[:, s] ./ Pj[:, s]'
    end
    Cj = (Pj ./ lambdaj) .^ (-1 / gamma)
    Yj = wj .* Lj

    # --- Annual aggregates: sums over blocks of 1/dt consecutive periods -------------
    T_dt = Int(T * uvars.dt)
    dt_inv = Int(1 / uvars.dt)
    annual_sum = kron(I(T_dt), ones(dt_inv))'  # T_dt × T; built once, not per country
    Cjann = zeros(Float64, N, T_dt)
    Yjann = zeros(Float64, N, T_dt)
    for j = 1:N
        Cjann[j, :] = annual_sum * Cj[j, :]
        Yjann[j, :] = annual_sum * Yj[j, :]
    end
    dcj = diff(log.(Cjann), dims = 2)
    dyj = diff(log.(Yjann), dims = 2)
    vol_dcj = mean(std(dcj, dims = 2))
    cor_dcj = cor(dcj, dims = 2)
    cor_dyj = cor(dyj, dims = 2)
    mean_cor_dcj = mean(no_diag_mean(cor_dcj[1:43, 1:43], 2))
    mean_cor_dyj = mean(no_diag_mean(cor_dyj[1:43, 1:43], 2))

    # --- Exchange-rate volatility ------------------------------------------------------
    dsij = diff(log.(Sij), dims = 3)
    Sijann = Sij[:, :, 1:dt_inv:end]  # every dt_inv-th period, starting at period 1
    dsij_ann = diff(log.(Sijann), dims = 3)
    Sigma = reshape(std(dsij, dims = 3), N, N)
    vol_dsij = mean(vec(remove_diagonal(Sigma))) * sqrt(dt_inv)

    # --- Annual exchange-rate changes regressed on consumption-growth differentials ---
    BS_coef = ones(N, N)
    for j = 1:N, i = 1:N
        i == j && continue
        x_var_BS = [ones(Float64, size(dcj, 2), 1) vec(dcj[i, :] - dcj[j, :])]
        y_var_BS = vec(dsij_ann[j, i, :])
        BS_coef[j, i] = (x_var_BS \ y_var_BS)[2]  # slope coefficient
    end
    mean_BS_coeff = mean(BS_coef[Int.(I(N)).==0])

    # Calibration targets
    return round.(
        [vol_dcj, vol_dsij, mean_cor_dcj, mean_cor_dyj, mean_BS_coeff];
        digits = 3,
    )
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    solver_CE(calibration, states, uvars; Analysis_out = false)

Solve the Autarky market economy for a calibration and state vector.

Prices are solved period by period with `H_CE_excess_demand`. The update is
`pjt = pjt .* (1 .+ stepsize * H)` until `maximum(abs.(H)) <= tol`, each period starts
from the previous period's solution, and the results are stored in `uvars.pj`. With
`Cj = Yj ./ Pj`, the implied weights `lambdaj_ce = Pj ./ Cj .^ (-gamma)` are rescaled so
that every column sums to `1 / N^2`.

# Arguments
- `calibration`: needs `eta, gamma, N, Lj, Tau, stepsize, tol` (`lambdaj` is not read).
- `states`: needs `zj` (N×T productivity) and `T`.
- `uvars`: a `utilityvars_autarky`; its fields are overwritten. It may be reused across
  calls: the convergence check is reset for every period, including the first.

# Returns
- `Analysis_out = false`: the N×T matrix `lambdaj_ce_scld` of scaled Pareto weights, one
  column per period.
- `Analysis_out = true`: the tuple `(lambdaj_ce_scld, Autarky_output)`, where
  `Autarky_output` is the 5-element vector
  `[vol_dcj, vol_dsij, mean_cor_dcj, mean_cor_dyj, mean_BS_coeff]` rounded to 3 digits.

NOTE(review): country 44 is the price numeraire and `1:43` is hard-coded (requires
`N >= 44`).

# Examples
```jldoctest
julia> lambdaj_ce_scld, Autarky_output = solver_CE(calibrationAut, states, uvars_autarky;
  Analysis_out = true)
44×600 Matrix{Float64}:
   1.11458e-12  2.59759e-12  1.79469e-12  …  9.19599e-13  7.08931e-13
   1.51056e-16  2.02387e-16  1.76841e-15     1.72921e-14  2.12052e-14
   4.86611e-15  1.50542e-14  7.76642e-15     5.33859e-13  2.11221e-13
   ⋮                                      ⋱
   3.18514e-16  2.38577e-16  6.4497e-17      8.45127e-17  2.47297e-17
   0.000119847  2.83233e-5   5.51917e-5      0.00051538   0.000516039
5-element Vector{Float64}:
 0.274
 0.035
 0.007
 0.512
 0.048
```
"""
function solver_CE(calibration, states, uvars; Analysis_out = false)
    @unpack zj, T = states
    @unpack eta, gamma, N, Lj, Tau, stepsize, tol = calibration

    # --- Market-clearing tradable prices, one period at a time ---------------------------
    for s = 1:T
        # Reset H so the loop runs at least once, even when `uvars` is reused and still
        # holds a converged H from an earlier call.
        uvars.H = ones(N)
        if s > 1
            uvars.pjt = uvars.pj[:, s-1]  # warm start from previous period
        end
        z_s = zj[:, s]
        while maximum(abs.(uvars.H)) > tol
            uvars.H = H_CE_excess_demand(uvars.pjt, z_s, calibration, uvars)
            uvars.pjt = uvars.pjt .* (1 .+ stepsize * uvars.H)
        end
        uvars.pj[:, s] = uvars.pjt
    end

    # --- Implied Pareto weights ------------------------------------------------------------
    pj = uvars.pj ./ repeat(uvars.pj[44, :]' .* zj[44, :]', N, 1)
    wj = zj .* pj
    Pj = ones(N, T)
    for s = 1:T
        Pj[:, s] =
            sum((Tau .* pj[:, s]') .^ (1 - eta), dims = 2) .^ (1 / (1 - eta))
    end
    Yj = wj .* Lj
    Cj = Yj ./ Pj
    lambdaj_ce = Pj ./ Cj .^ (-gamma)
    lambdaj_ce_scld = lambdaj_ce ./ (sum(lambdaj_ce, dims = 1) * N^2)

    if !Analysis_out
        return lambdaj_ce_scld
    end

    # --- Moments (only when requested) ---------------------------------------------------
    T_dt = Int(T * uvars.dt)
    dt_inv = Int(1 / uvars.dt)

    Sij = ones(N, N, T)
    for s = 1:T
        Sij[:, :, s] = Pj[:, s] ./ Pj[:, s]'
    end

    # Annual sums over blocks of dt_inv consecutive periods; matrix built once
    annual_sum = kron(I(T_dt), ones(dt_inv))'
    Cjann = zeros(Float64, N, T_dt)
    Yjann = zeros(Float64, N, T_dt)
    for j = 1:N
        Cjann[j, :] = annual_sum * Cj[j, :]
        Yjann[j, :] = annual_sum * Yj[j, :]
    end
    dcj = diff(log.(Cjann), dims = 2)
    dyj = diff(log.(Yjann), dims = 2)
    vol_dcj = mean(std(dcj, dims = 2))
    cor_dcj = cor(dcj, dims = 2)
    cor_dyj = cor(dyj, dims = 2)
    mean_cor_dcj = mean(no_diag_mean(cor_dcj[1:43, 1:43], 2))
    mean_cor_dyj = mean(no_diag_mean(cor_dyj[1:43, 1:43], 2))

    # Exchange rates
    dsij = diff(log.(Sij), dims = 3)
    Sijann = Sij[:, :, 1:dt_inv:end]
    dsij_ann = diff(log.(Sijann), dims = 3)
    Sigma = reshape(std(dsij, dims = 3), N, N)
    vol_dsij = mean(vec(remove_diagonal(Sigma))) * sqrt(dt_inv)

    # Annual exchange-rate changes regressed on consumption-growth differentials
    BS_coef = ones(N, N)
    for j = 1:N, i = 1:N
        i == j && continue
        x_var_BS = [ones(Float64, size(dcj, 2), 1) vec(dcj[i, :] - dcj[j, :])]
        y_var_BS = vec(dsij_ann[j, i, :])
        BS_coef[j, i] = (x_var_BS \ y_var_BS)[2]  # slope coefficient
    end
    mean_BS_coeff = mean(BS_coef[Int.(I(N)).==0])

    Autarky_output = round.(
        [vol_dcj, vol_dsij, mean_cor_dcj, mean_cor_dyj, mean_BS_coeff];
        digits = 3,
    )
    return lambdaj_ce_scld, Autarky_output
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    solver_incomplete(calibration, states, uvars)

Solve the incomplete market economy for a calibration and state vector.

Prices are solved period by period with `H_Inc_excess_demand`, using the period-`s`
Pareto weights `lambdaj[:, s]`. The update is `pjt = pjt .* (1 .+ stepsize * H)` until
`maximum(abs.(H)) <= tol`, each period starts from the previous period's solution, and
the results are stored in `uvars.pj`.

# Arguments
- `calibration`: needs `eta, gamma, N, Tau, Lj, lambdaj, stepsize, tol`; `lambdaj` must
  be indexable as `lambdaj[:, s]` for `s = 1:T` (an N×T matrix of weights).
- `states`: needs `zj` (N×T productivity) and `T`.
- `uvars`: a `utilityvars_incomplete`; its fields are overwritten. It may be reused
  across calls: the convergence check is reset for every period, including the first.

# Returns
A tuple `(pj_inc, out_dict)`:
- `pj_inc`: the N×T matrix of solved (un-normalized) tradable prices. This is the array
  `uvars.pj` itself, not a copy.
- `out_dict`: a `Dict{String, Any}` with the normalized prices, quantities, import
  shares, exchange rates and moments (e.g. `"pj"`, `"Cj"`, `"Sigma"`, `"basej"`, `"beta"`,
  `"R2"`, `"BS_coef"`, `"mean_mij"`, `"mean_Pj"`, `"mean_Πj"`). It also holds
  `"Incomp_output"`, the 5-element vector
  `[vol_dcj, vol_dsij, mean_cor_dcj, mean_cor_dyj, mean_BS_coeff]` rounded to 3 digits.

NOTE(review): country 44 is the price numeraire and `1:43` is hard-coded (requires
`N >= 44`).

# Examples
```jldoctest
julia> _, dict_out = solver_incomplete(calibrationIncomplete, states, uvars_incomplete)
Dict{String, Any} with 38 entries:
  "pj"        => [2.06241 2.00882 … 1.28608 1.25759; 1.77032 2.22668 … 1.1389…
  "dc_period" => [0.0046177 0.00548773 … -0.00449541 0.00285386; -0.0163339 0…
  "vol_dcj"   => 0.0333531
  "basej"     => [-0.0106033 0.0192505 … 0.0050676 0.000154298; 0.15584 -0.03…
  "Yj"        => [0.107057 0.113641 … 0.0846667 0.083926; 0.0270664 0.0346689…
  ⋮           => ⋮
```
"""
function solver_incomplete(calibration, states, uvars)
    @unpack zj, T = states
    @unpack eta, gamma, N, Tau, Lj, lambdaj, stepsize, tol = calibration

    # --- Market-clearing tradable prices, one period at a time ---------------------------
    for s = 1:T
        # Reset H so the loop runs at least once, even when `uvars` is reused and still
        # holds a converged H from an earlier call.
        uvars.H = ones(N, 1)
        if s > 1
            uvars.pjt .= @view(uvars.pj[:, s-1])  # warm start from previous period
        end
        lambdaj_s = lambdaj[:, s]
        while maximum(abs.(uvars.H)) > tol
            uvars.H = H_Inc_excess_demand(
                uvars.pjt,
                @view(zj[:, s]),
                calibration,
                uvars,
                lambdaj_s,
            )
            uvars.pjt = uvars.pjt .* (1 .+ stepsize * uvars.H)
        end
        uvars.pj[:, s] = uvars.pjt
    end
    pj_inc = uvars.pj

    # --- Normalized prices, price indices, import shares, exchange rates ---------------
    pj = pj_inc ./ repeat((pj_inc[44, :])' .* zj[44, :]', N, 1)
    wj = zj .* pj
    Pj = ones(N, T)
    Sij = ones(N, N, T)
    mij = ones(N, N, T)
    for s = 1:T
        Pj[:, s] =
            sum((Tau .* pj[:, s]') .^ (1 - eta), dims = 2) .^ (1 / (1 - eta))
        mij[:, :, s] =
            (Tau .* pj[:, s]') .^ (1 - eta) ./
            repeat(Pj[:, s] .^ (1 - eta), 1, N)
        Sij[:, :, s] = Pj[:, s] ./ Pj[:, s]'
    end
    Cj = (Pj ./ lambdaj) .^ (-1 / gamma)
    Yj = wj .* Lj

    # --- Annual aggregates and their moments -----------------------------------------------
    T_dt = Int(T * uvars.dt)
    dt_inv = Int(1 / uvars.dt)
    annual_sum = kron(I(T_dt), ones(dt_inv))'  # T_dt × T block-sum matrix, built once
    Cjann = zeros(Float64, N, T_dt)
    Yjann = zeros(Float64, N, T_dt)
    for j = 1:N
        Cjann[j, :] = annual_sum * Cj[j, :]
        Yjann[j, :] = annual_sum * Yj[j, :]
    end
    dcj = diff(log.(Cjann), dims = 2)
    dyj = diff(log.(Yjann), dims = 2)
    vol_dcj = mean(std(dcj, dims = 2))
    vol_dyj = mean(std(dyj, dims = 2))
    cor_dcj = cor(dcj, dims = 2)
    cor_dyj = cor(dyj, dims = 2)
    mean_cor_dcj = mean(no_diag_mean(cor_dcj[1:43, 1:43], 2))
    mean_cor_dyj = mean(no_diag_mean(cor_dyj[1:43, 1:43], 2))

    # Import shares averaged over periods
    mean_mij = mean(mij, dims = 3)
    mijbar = mean(vec(mean_mij))

    # Exchange rates
    dsij = diff(log.(Sij), dims = 3)
    Sijann = Sij[:, :, 1:dt_inv:end]
    dsij_ann = diff(log.(Sijann), dims = 3)
    Sigma = reshape(std(dsij, dims = 3), N, N)
    vol_dsij = mean(vec(remove_diagonal(Sigma))) * sqrt(dt_inv)

    dPj = diff(log.(Pj), dims = 2)
    mean_Pj = vec(mean(Pj, dims = 2))
    Πj = (Yj .^ (1.0 / (1.0 - eta))) ./ pj
    mean_Πj = vec(mean(Πj, dims = 2))
    X_ij = mij .* reshape(Yj, N, 1, T)

    # Base factors: per-period row means of dsij after `remove_diagonal`
    basej = ones(N, T - 1)
    for s = 1:T-1
        basej[:, s] = mean(remove_diagonal(dsij[:, :, s]), dims = 2)
    end

    # --- Exchange-rate factor regressions --------------------------------------------------
    beta = ones(N, N)
    R2 = ones(N, N)
    BS_coef = ones(N, N)
    dc_period = diff(log.(Cj), dims = 2)
    for j = 1:N, i = 1:N
        i == j && continue

        # dsij[j, i, :] on country j's base factor
        base_reg = vec(basej[j, :])
        x_var = [ones(Float64, T - 1, 1) base_reg]
        y_var = vec(dsij[j, i, :])
        beta[j, i] = (x_var \ y_var)[2]
        R2[j, i] = cor(y_var, base_reg) .^ 2

        # annual exchange-rate change on the consumption-growth differential
        x_var_BS = [ones(Float64, size(dcj, 2), 1) vec(dcj[i, :] - dcj[j, :])]
        y_var_BS = vec(dsij_ann[j, i, :])
        BS_coef[j, i] = (x_var_BS \ y_var_BS)[2]
    end
    mean_BS_coeff = mean(BS_coef[Int.(I(N)).==0])

    Incomp_output = round.(
        [vol_dcj, vol_dsij, mean_cor_dcj, mean_cor_dyj, mean_BS_coeff];
        digits = 3,
    )

    out_dict = Dict(
        "calibration" => calibration,
        "states" => states,
        "pj" => pj,
        "wj" => wj,
        "Cj" => Cj,
        "Yj" => Yj,
        "Pj" => Pj,
        "mij" => mij,
        "Sij" => Sij,
        "Cjann" => Cjann,
        "Yjann" => Yjann,
        "dcj" => dcj,
        "dyj" => dyj,
        "vol_dcj" => vol_dcj,
        "vol_dyj" => vol_dyj,
        "cor_dcj" => cor_dcj,
        "cor_dyj" => cor_dyj,
        "mean_cor_dcj" => mean_cor_dcj,
        "mean_cor_dyj" => mean_cor_dyj,
        "mean_mij" => mean_mij,
        "mijbar" => mijbar,
        "dsij" => dsij,
        "Sijann" => Sijann,
        "dsij_ann" => dsij_ann,
        "Sigma" => Sigma,
        "vol_dsij" => vol_dsij,
        "mean_Pj" => mean_Pj,
        "dPj" => dPj,
        "Πj" => Πj,
        "mean_Πj" => mean_Πj,
        "X_ij" => X_ij,
        "basej" => basej,
        "beta" => beta,
        "R2" => R2,
        "BS_coef" => BS_coef,
        "dc_period" => dc_period,
        "unvar" => uvars,
        "Incomp_output" => Incomp_output,
    )

    return pj_inc, out_dict
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    H_CE_excess_demand(pj, zj, calibration, uvars)

Compute the excess demand for every tradable good in the Autarky market economy at the
price guess `pj`.

With `R[i, k] = (Tau[i, k] * pj[k] / pj[i])^(1 - eta)`, the demand for good `k` is
`sum_i R[i, k] / sum_m R[i, m] * (pj[i] / pj[k]) * zj[i] * Lj[i]`. The supply
`zj[k] * Lj[k]` is subtracted from it.

# Arguments
- `pj`: the current guess of prices for tradables at this time period (length N).
- `zj`: the state (productivity) vector of this time period.
- `calibration`: provides `eta`, `N`, `Tau` and `Lj`.
- `uvars`: work space; `uvars.R` is overwritten with the N×N matrix `R`.

# Returns
A length-N `Vector` of excess demands.
"""
function H_CE_excess_demand(pj, zj, calibration, uvars)
    @unpack eta, N, Tau, Lj = calibration

    # row i = buyer, column k = good
    uvars.R = (Tau .* pj' ./ pj) .^ (1 - eta)
    R = uvars.R
    demand = R .* (pj ./ pj') ./ sum(R, dims = 2) .* zj .* Lj
    supply = (zj .* Lj)'
    return vec(sum(demand, dims = 1) - supply)
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    H_PP_excess_demand(pj, zj, calibration, uvars)

Compute the excess demand for all tradable goods in the complete market economy.

The price index of buyer `j` is
`Pj[j] = (sum_i (Tau[j, i] * pj[i])^(1 - eta))^(1 / (1 - eta))`, and its consumption
level is `Cj[j] = (Pj[j] / lambdaj[j])^(-1 / gamma)`. Buyer `j` demands
`(Tau[j, i] * pj[i] / Pj[j])^(-eta) * Cj[j]` of good `i`. These demands are summed over
buyers and the supply `zj .* Lj` is subtracted.

Side effect: `uvars.Pj` is overwritten with the N×1 price indices.

# Returns
An N×1 array of excess demands, one per good.
"""
function H_PP_excess_demand(pj, zj, calibration, uvars)
    @unpack eta, gamma, lambdaj, N, Tau, Lj = calibration

    # delivered prices: row j = buyer, column i = good
    delivered = Tau .* reshape(pj, 1, :)
    uvars.Pj = sum(delivered .^ (1 - eta), dims = 2) .^ (1 / (1 - eta))
    Pidx = uvars.Pj
    Cj = (Pidx ./ lambdaj) .^ (-1 / gamma)
    demand = (Tau .* pj' ./ Pidx) .^ (-eta) .* Cj
    return sum(demand, dims = 1)' - (zj .* Lj)
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    H_Inc_excess_demand(pj, zj, calibration, uvars, lambdaj_vec)

Compute the excess demand for all tradable goods in the incomplete market economy.

This is the same demand system as the complete-market case, but it uses the Pareto
weights `lambdaj_vec` passed in for the current period (`solver_incomplete` passes
`lambdaj[:, s]`). `calibration.lambdaj` is therefore not read here.

# Arguments
- `pj`: current price guess for tradables (length N).
- `zj`: productivity vector of this period.
- `calibration`: provides `eta`, `gamma`, `Tau` and `Lj`.
- `uvars`: work space; `uvars.Pj` is overwritten with the N×1 price indices.
- `lambdaj_vec`: Pareto weights for this period (length N).

# Returns
An N×1 array of excess demands, one per good.
"""
function H_Inc_excess_demand(pj, zj, calibration, uvars, lambdaj_vec)
    @unpack eta, gamma, Tau, Lj = calibration

    # CES price index of each buyer j over delivered prices Tau[j, i] * pj[i]
    uvars.Pj =
        sum((Tau .* reshape(pj, 1, :)) .^ (1 - eta), dims = 2) .^
        (1 / (1 - eta))
    Pidx = uvars.Pj
    Cj = (Pidx ./ lambdaj_vec) .^ (-1 / gamma)
    # demand of buyer j (rows) for good i (columns); broadcasting replaces `repeat`
    demand = (Tau .* pj' ./ Pidx) .^ (-eta) .* Cj
    return sum(demand, dims = 1)' - (zj .* Lj)
end
# ---------------------------------------------------------



# ---------------------------------------------------------
## Fig 4 and Fig 5 Functions
"""
    myplotter_draft(;
        X,
        Y,
        rho_loc = [],
        xx_line_bound = [],
        reg_line = false,
        y_label = "",
        x_label = "",
        title = "",
    )

Scatter plot with the style and format of figures 4 and 5 of the paper.

# Keywords
- `X`, `Y`: data to scatter.
- `rho_loc`: `(x, y)` position at which the correlation of `X[:]` and `Y[:]` (rounded to
  3 digits) is printed as `ρ = ...`. Leave it empty (the default) to omit it. Any
  indexable collection, such as a vector or tuple, is accepted.
- `xx_line_bound`: `(xmin, xmax)` range of a red `y = x` line; empty to omit it.
- `reg_line`: if `true`, overlay a red regression line (Plots' `reg = true`).
- `y_label`, `x_label`, `title`: axis labels and plot title.

Returns the plot object. The annotation, lines and labels are added with the mutating
Plots functions, which act on the current plot (the one created by `scatter` here).

```jldoctest
julia> X, Y = 100 * mij_data[Int.(I(N)).==0], 100 * mij_share[Int.(I(N)).==0]
julia> p1_a = myplotter_draft(;
  X = X, Y = Y,
  rho_loc = [20, 30],
  xx_line_bound = [0, 40],
  y_label = "Model",
  title = "Import Share (pct.)" );
```
"""
function myplotter_draft(;
    X,
    Y,
    rho_loc = [],
    xx_line_bound = [],
    reg_line = false,
    y_label = "",
    x_label = "",
    title = "",
)
    p_out = scatter(
        X,
        Y,
        xtickfont = font(12),
        ytickfont = font(12),
        color = :black,
        m = (4, :white, stroke(1, :black)),  # white size-4 markers, 1-px black outline
        smooth = false,
        guidefont = font(14),
    )

    # `isempty` also works for tuples, unlike the former `size(v[:])[1] > 0`.
    if !isempty(rho_loc)
        rho = round(cor(X[:], Y[:]), digits = 3)
        annotate!(rho_loc[1], rho_loc[2], string("ρ = ", rho))
    end
    if !isempty(xx_line_bound)
        plot!(x -> x, xx_line_bound[1], xx_line_bound[2], color = :red, lw = 3)
    end
    if reg_line
        # markers hidden (size 0, transparent) so only the `reg = true` fit shows
        scatter!(
            X,
            Y,
            color = :red,
            markersize = 0,
            marker = :transparent,
            reg = true,
            lw = 3,
        )
    end
    xlabel!(x_label)
    ylabel!(y_label)
    title!(title)

    return p_out
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    bilateral_trade_deal_cf(
        index_,
        states,
        calibration;
        TauReductionFactor,
        PrintProgress = false,
        )

Perform one bilateral trade deal counterfactual for the country pair `(i, j)` given by
the first two entries of `index_`.

Both directions of the pair's trade cost are set to
`(Tau - 1) * TauReductionFactor + 1`. The autarky economy is solved under the new costs
(`solver_CE`), and its scaled Pareto weights are blended with `calibration.lambdaj` as
`lambdaj_ce_scld * (1 - phi) .+ lambdaj * phi`. The incomplete market economy is then
re-solved (`solver_incomplete`). The solver step sizes (0.2 and 0.30) and tolerance
(1e-5) are fixed here rather than taken from `calibration`.

# Arguments
- `index_`: collection whose first two entries are the country indices `i, j`; it is
  stored unchanged under the `"index"` key of the output.
- `states`: needs fields `zj` and `T` (e.g. a `shocks`).
- `calibration`: needs `eta, gamma, N, dt, phi, Lj, lambdaj, Tau`
  (e.g. a `parameters_TradeWar`).

# Keywords
- `TauReductionFactor`: multiplier applied to `Tau - 1` for the pair.
- `PrintProgress`: if `true`, log an `@info` message when the counterfactual starts.

# Returns
A `Dict` with keys `"index"`, `"var_base_out"`, `"phi_ij_out"`, `"Pj_out"`, `"Πj_out"`,
`"mij_out"`, `"R2ij_out"`, `"betaij_out"`, `"rho_out"` and `"sigma_out"`. Every value
except `"index"` is a 2-element vector holding the statistic for `(i, j)` and `(j, i)`
(or for country `i` and country `j`), as consumed by `bilateral_trade_deal_analyze`.

# Examples
```julia
julia> out = bilateral_trade_deal_cf([1 4 12],
        states,
        calibration;
        TauReductionFactor = 0.8,
        PrintProgress = true
        );
```
"""
function bilateral_trade_deal_cf(
    index_,
    states,
    calibration;
    TauReductionFactor,
    PrintProgress = false,
)
    i, j = index_
    @unpack eta, gamma, N, dt, phi, Lj, lambdaj, Tau = calibration
    T = states.T
    if PrintProgress
        @info "Bilateral trade deal counterfactual" i j TauReductionFactor
    end

    # Scale the part of the pair's trade cost above 1, in both directions.
    Tau_cf = copy(Tau)
    Tau_cf[i, j] = (Tau_cf[i, j] - 1.0) * TauReductionFactor + 1.0
    Tau_cf[j, i] = (Tau_cf[j, i] - 1.0) * TauReductionFactor + 1.0

    # Autarky weights under the new costs. `solver_CE` does not read lambdaj, hence `[]`.
    # Only the weights are needed, so the autarky moments are not computed.
    calibration_cf = parameters(eta, gamma, N, dt, Lj, [], Tau_cf, 0.2, 1e-5)
    uvars_autarky = utilityvars_autarky(N = N, T = T, dt = dt)
    lambdaj_ce_scld =
        solver_CE(calibration_cf, states, uvars_autarky; Analysis_out = false)
    lambdaj_cf = lambdaj_ce_scld * (1.0 - phi) .+ lambdaj * phi

    # Incomplete market economy with the blended weights.
    uvars_incomplete = utilityvars_incomplete(N = N, T = T, dt = dt)
    calibration_incom_cf =
        parameters(eta, gamma, N, dt, Lj, lambdaj_cf, Tau_cf, 0.30, 1e-5)
    _, dict_out_cf = solver_incomplete(calibration_incom_cf, states, uvars_incomplete)

    # Pair statistics of the counterfactual economy.
    var_base = std(dict_out_cf["basej"], dims = 2)
    rho_ij_cf =
        dict_out_cf["Sigma"] ./ (repeat(var_base, 1, N) + repeat(var_base', N, 1))
    dpii_sd = std(diff(log.(dict_out_cf["pj"]), dims = 2), dims = 2)
    phi_ij =
        dict_out_cf["mean_mij"][:, :, 1] ./
        mean(dict_out_cf["pj"], dims = 2)' .* dpii_sd'

    mean_Pj = dict_out_cf["mean_Pj"]
    mean_Πj = dict_out_cf["mean_Πj"]
    mean_mij = dict_out_cf["mean_mij"]
    R2 = dict_out_cf["R2"]
    beta = dict_out_cf["beta"]
    Sigma = dict_out_cf["Sigma"]

    output = Dict(
        "index" => index_,
        "var_base_out" => [vec(var_base)[i], vec(var_base)[j]],
        "phi_ij_out" => [phi_ij[i, j], phi_ij[j, i]],
        "Pj_out" => [mean_Pj[i], mean_Pj[j]],
        "Πj_out" => [mean_Πj[i], mean_Πj[j]],
        "mij_out" => [mean_mij[i, j, 1], mean_mij[j, i, 1]],
        "R2ij_out" => [R2[i, j, 1], R2[j, i, 1]],
        "betaij_out" => [beta[i, j, 1], beta[j, i, 1]],
        "rho_out" => [rho_ij_cf[i, j], rho_ij_cf[j, i]],
        "sigma_out" => [Sigma[i, j, 1], Sigma[j, i, 1]],
    )

    return output
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    bilateral_trade_deal_analyze(; dict_out_org, BilateralDealsOutputs)

Analyze the results of bilateral trade deal tests and assemble Table 11.

# Keywords
- `dict_out_org`: output dictionary of `solver_incomplete` for the baseline economy
  (reads `"pj"`, `"basej"`, `"Sigma"`, `"mean_mij"`, `"beta"` and `"R2"`).
- `BilateralDealsOutputs`: iterable of dictionaries returned by
  `bilateral_trade_deal_cf`, one per country pair.

# Returns
A 15-row `DataFrame` with columns `Panel_names_t11`, `Statistic`, `Volatility`,
`Factor_Loading`, `R2` and `Unshared_Risk`. It has three panels:
- `"Aggregate"`: all countries.
- `"Core"`: countries whose `no_diag_mean(phi_ij_org, 2)` exceeds the median.
- `"Peripheral"`: the rest.

Each panel has the rows "Before: Mean", "Before: Stdev" and "After: Mean" (from
`tab_cfstat`), followed by "Difference" (after minus before mean) and
"Difference/Stdev" (absolute difference over the before stdev). All values are rounded
to 3 digits.

# Examples
```jldoctest
julia> Table_11 = bilateral_trade_deal_analyze(;
  dict_out_org = dict_in,
  BilateralDealsOutputs = GenOutputs,
)
```
"""
function bilateral_trade_deal_analyze(; dict_out_org, BilateralDealsOutputs)
    N = size(dict_out_org["pj"], 1)

    # --- Counterfactual pair statistics as N×N matrices (diagonal set by `nan_diag`) ----
    R2ij_cf = nan_diag(ones(N, N))
    betaij_cf = nan_diag(ones(N, N))
    rho_cf = nan_diag(ones(N, N))
    sigma_cf = nan_diag(ones(N, N))
    for cf_econ in BilateralDealsOutputs
        i, j = cf_econ["index"]
        R2ij_cf[i, j], R2ij_cf[j, i] = cf_econ["R2ij_out"]
        betaij_cf[i, j], betaij_cf[j, i] = cf_econ["betaij_out"]
        rho_cf[i, j], rho_cf[j, i] = cf_econ["rho_out"]
        sigma_cf[i, j], sigma_cf[j, i] = cf_econ["sigma_out"]
    end

    # --- Baseline counterparts -------------------------------------------------------------
    Sigma_org = dict_out_org["Sigma"]
    beta_org = dict_out_org["beta"]
    R2_org = dict_out_org["R2"]
    var_base_org = std(dict_out_org["basej"], dims = 2)
    rho_ij_org =
        Sigma_org ./ (repeat(var_base_org, 1, N) + repeat(var_base_org', N, 1))

    # Core / periphery split
    dpii_sd = std(diff(log.(dict_out_org["pj"]), dims = 2), dims = 2)
    phi_ij_org =
        dict_out_org["mean_mij"][:, :, 1] ./
        mean(dict_out_org["pj"], dims = 2)' .* dpii_sd'
    stat_div = no_diag_mean(phi_ij_org, 2)[:]
    central_ix = stat_div .> median(stat_div)

    # Restrict a matrix to the countries selected by `ix` (`:` keeps the matrix itself).
    pick(M, ix) = ix isa Colon ? M : M[ix, ix]

    # One 5×4 panel: three `tab_cfstat` rows plus "Difference" and "Difference/Stdev".
    function panel(ix)
        # NOTE(review): sqrt(12) annualizes the per-period volatility only when dt = 1/12.
        stats = vcat(
            sqrt(12) * tab_cfstat(pick(Sigma_org, ix), pick(sigma_cf, ix)),
            tab_cfstat(pick(beta_org, ix), pick(betaij_cf, ix)),
            tab_cfstat(pick(R2_org, ix), pick(R2ij_cf, ix)),
            tab_cfstat(pick(rho_ij_org, ix), pick(rho_cf, ix)),
        )'
        change = stats[3, :] - stats[1, :]
        return vcat(
            stats,
            reshape(change, 1, :),
            reshape(abs.(change ./ stats[2, :]), 1, :),
        )
    end

    Table_11 = DataFrame(
        round.(
            vcat(panel(:), panel(central_ix), panel(.!central_ix));
            digits = 3,
        ),
        :auto,
    )
    rename!(Table_11, [:Volatility, :Factor_Loading, :R2, :Unshared_Risk])

    statistic_names = [
        "Before: Mean",
        "Before: Stdev",
        "After: Mean",
        "Difference",
        "Difference/Stdev",
    ]
    Table_11.Statistic = repeat(statistic_names, 3)
    Table_11.Panel_names_t11 = repeat(["Aggregate", "Core", "Peripheral"], inner = 5)
    Table_11 = Table_11[!, [6, 5, 1, 2, 3, 4]]

    return Table_11
end
# ---------------------------------------------------------


# ---------------------------------------------------------
## Table 12 Functions
## This code performs counterfactual 2: US-China trade tensions
"""
    TradeWar_USChina(countries_partners; calibration, Tau_inc_factor, dict_out_org)

Solve the counterfactual of a US-China trade war and tabulate its effects (Table 12).

The net bilateral trade cost ``τ - 1`` between the USA and China is scaled by
`1 + Tau_inc_factor` in both directions. The Pareto weights are then re-derived with
`solver_CE` and blended with the calibrated weights using `phi`. The economy is
re-solved with `solver_incomplete`, and its moments are compared with the benchmark
solution `dict_out_org`.

# Arguments
- `countries_partners`: collection destructured as
  `(USA_num, CHN_num, USA_partners, CHN_partners)`. The first two are country
  indices and the last two are `1×k` row matrices of partner indices
  (e.g. `44, 7, [5 25 42], [23 24 35]`).

# Keywords
- `calibration`: calibrated parameters. It must provide the fields `eta`, `gamma`,
  `N`, `dt`, `phi`, `Lj`, `lambdaj`, `Tau`, `stepsize` and `tol` (presumably a
  `parameters_TradeWar`).
- `Tau_inc_factor`: proportional increase of the USA-China net trade cost.
- `dict_out_org`: benchmark output. The keys read are `"states"`, `"basej"`,
  `"Sigma"`, `"mean_Pj"`, `"mean_Πj"`, `"mean_mij"`, `"R2"` and `"Incomp_output"`.

# Returns
A tuple `(dt_USACHN, dt_partners)` of DataFrames, the top and bottom panels of Table 12:
- `dt_USACHN`: 5 rows of USA/China moments in levels (rounded to 3 digits), before
  and after the counterfactual.
- `dt_partners`: 11 rows of percent changes, counterfactual vs benchmark (rounded to
  2 digits). Its column labels are fixed to Canada, Mexico, UK, Korea, Malaysia and
  Singapore, so `USA_partners` and `CHN_partners` must together hold six indices
  in that order.

# Examples
Output abbreviated:
```julia
julia> dt_USACHN, dt_partners = TradeWar_USChina(
    countries_and_partners;
    calibration = calibration,
    Tau_inc_factor = 0.5,
    dict_out_org = dict_in)
5×5 DataFrame
 Row │ Variable            USA_before  USA_after  China_be ⋯
     │ String              Float64     Float64    Float64  ⋯
─────┼──────────────────────────────────────────────────────
   1 │ Net Trade Cost           0.947      1.42          1 ⋯
   2 │ Inward Resistance        0.99       0.991         1
   3 │ Outward Resistance       1.09       1.09          1
   4 │ Import Share             0.027      0.012         0
   5 │ Unshared Risk            0.724      0.785         0 ⋯
11×7 DataFrame
Row │ Variables          Canada   Mexico   UK       Korea ⋯
    │ LaTeXStr…          Float64  Float64  Float64  Float ⋯
─────┼──────────────────────────────────────────────────────
  1 │ P_j                -0.44    -0.6     -0.85    -0. ⋯
  2 │ Π_j                 0.04     0.22     0.58     0.
  3 │ ω_{USA,j}          -1.06    -2.03    -1.05     0.
  4 │ ω_{CHN,j}          19.32    19.2     11.65     8.
  5 │ ω_{i,USA}           2.14     1.87     3.73     6. ⋯
  6 │ ω_{i,CHN}          11.32     7.72     4.34     2.
  7 │ \\overbar{σ_j}      2.03     0.89     0.17     0.
  8 │ R^2_{i,USA}        13.01    12.2      0.94     2.
  9 │ R^2_{i,CHN}        29.88    28.97     4.37    13. ⋯
 10 │ ρ_{i,USA}           0.36    -1.08    -0.16     3.
 11 │ ρ_{i,CHN}           8.13     5.75     0.24     5.
```
"""
function TradeWar_USChina(countries_partners;
    calibration, Tau_inc_factor, dict_out_org)

    @info "US-China trade war counterfactual"

    # `stepsize` and `tol` are unpacked (the fields must exist), but the
    # counterfactual solves below use fixed values (stepsize 0.2, tol 1e-5).
    @unpack eta, gamma, N, dt, phi, Lj, lambdaj, Tau, stepsize, tol = calibration
    lambdaj_pp_gamma = lambdaj
    states = dict_out_org["states"]
    T = states.T

    # Indices of the two economies and of their trading partners in the data.
    USA_num, CHN_num, USA_partners, CHN_partners = countries_partners
    USACHN = [USA_num CHN_num]                                # 1×2
    all_partners = cat(USA_partners, CHN_partners; dims = 2)  # 1×k

    # Counterfactual trade costs: scale the net cost (τ - 1) between the USA and
    # China, in both directions, by (1 + Tau_inc_factor).
    scale_net_cost(tau) = (1 + Tau_inc_factor) * (tau - 1.0) + 1.0
    Tau_calib = copy(Tau)
    Tau_cf = copy(Tau_calib)
    Tau_cf[CHN_num, USA_num] = scale_net_cost(Tau_calib[CHN_num, USA_num])
    Tau_cf[USA_num, CHN_num] = scale_net_cost(Tau_calib[USA_num, CHN_num])

    ## Counterfactual economy
    @info "Simulation: US-China trade war Counterfactual Economy"
    # Re-derive the Pareto weights under the new trade costs, then blend them
    # with the calibrated weights using `phi`.
    calibration_cf =
        parameters(eta, gamma, N, dt, Lj, lambdaj_pp_gamma, Tau_cf, 0.2, 1e-5)
    uvars_autarky = utilityvars_autarky(N = N, T = T, dt = dt)
    lambdaj_ce_scld_cf, _ = solver_CE(calibration_cf, states, uvars_autarky;
        Analysis_out = true)
    lambdaj_cf = lambdaj_ce_scld_cf * (1.0 - phi) .+ lambdaj_pp_gamma * phi

    # Incomplete-markets solution with the blended weights.
    uvars_incomplete = utilityvars_incomplete(N = N, T = T, dt = dt)
    calibrationIncomplete =
        parameters(eta, gamma, N, dt, Lj, lambdaj_cf, Tau_cf, 0.2, 1e-5)
    _, dict_out_cf = solver_incomplete(calibrationIncomplete, states, uvars_incomplete)
    CF_output = dict_out_cf["Incomp_output"]

    # "Unshared risk" matrix: Sigma[i, j] / (s_i + s_j), where s_j is the
    # standard deviation of row j of "basej" (taken along dims = 2).
    function unshared_risk(d)
        s = std(d["basej"], dims = 2)
        return d["Sigma"] ./ (repeat(s, 1, N) + repeat(s', N, 1))
    end

    # 5×2 USA/China moments (column 1: USA, column 2: China), rounded to 3 digits.
    # Row order matches `Var_names` below.
    function usachn_moments(d, Tau_m, rho)
        return round.(
            [
                [(Tau_m[USA_num, CHN_num] - 1) (Tau_m[CHN_num, USA_num] - 1)]
                d["mean_Pj"][USACHN]
                d["mean_Πj"][USACHN]
                [d["mean_mij"][USA_num, CHN_num, 1] d["mean_mij"][CHN_num, USA_num, 1]]
                [rho[USA_num, CHN_num, 1] rho[CHN_num, USA_num, 1]]
            ];
            digits = 3,
        )
    end

    # 11×k partner moments (one column per partner). Row order matches
    # `Var_names2` below.
    function partner_moments(d, rho)
        return [
            # Trade variables
            d["mean_Pj"][all_partners]
            d["mean_Πj"][all_partners]
            d["mean_mij"][all_partners, USA_num, 1]
            d["mean_mij"][all_partners, CHN_num, 1]
            d["mean_mij"][USA_num, all_partners, 1]
            d["mean_mij"][CHN_num, all_partners, 1]
            # Risk measures
            no_diag_mean(d["Sigma"], 2)[all_partners, 1]
            d["R2"][USA_num, all_partners]
            d["R2"][CHN_num, all_partners]
            rho[USA_num, all_partners]
            rho[CHN_num, all_partners]
        ]
    end

    # Effects on the United States and China (levels).
    rho_ij_or = unshared_risk(dict_out_org)
    rho_ij_cf = unshared_risk(dict_out_cf)
    CHNUSA_mom_or = usachn_moments(dict_out_org, Tau_calib, rho_ij_or)
    CHNUSA_mom_cf = usachn_moments(dict_out_cf, Tau_cf, rho_ij_cf)

    Var_names = [
        "Net Trade Cost",
        "Inward Resistance",
        "Outward Resistance",
        "Import Share",
        "Unshared Risk",
    ]
    dt_USACHN = DataFrame(
        Variable = Var_names,
        USA_before = CHNUSA_mom_or[:, 1],
        USA_after = CHNUSA_mom_cf[:, 1],
        China_before = CHNUSA_mom_or[:, 2],
        China_after = CHNUSA_mom_cf[:, 2],
    )

    # Effects on peripheral countries, in percent change.
    mom_org = partner_moments(dict_out_org, rho_ij_or)
    mom_cf = partner_moments(dict_out_cf, rho_ij_cf)
    table_out = round.(100.0 * (mom_cf ./ mom_org .- 1); digits = 2)

    # NOTE(review): `\overbar` is not a built-in LaTeX command. Confirm it is
    # defined wherever these labels are rendered.
    Var_names2 = [
        L"P_j",
        L"Π_j",
        L"ω_{USA,j}",
        L"ω_{CHN,j}",
        L"ω_{i,USA}",
        L"ω_{i,CHN}",
        L"\overbar{σ_j}",
        L"R^2_{i,USA}",
        L"R^2_{i,CHN}",
        L"ρ_{i,USA}",
        L"ρ_{i,CHN}",
    ]

    dt_partners = DataFrame(
        Variables = Var_names2,
        Canada = table_out[:, 1],
        Mexico = table_out[:, 2],
        UK = table_out[:, 3],
        Korea = table_out[:, 4],
        Malaysia = table_out[:, 5],
        Singapore = table_out[:, 6],
    )

    Benchmark_output = dict_out_org["Incomp_output"]

    # print outputs
    @info "Benchmark:                         ", Benchmark_output
    @info "US-China trade war counterfactual: ", CF_output
    @info dt_USACHN
    @info dt_partners

    return dt_USACHN, dt_partners
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    random_productivity_state(N, t, sigma, psi, years, dt, BURN_IN)

Generate `T = years / dt` periods of random productivity paths for `N` countries.

Log productivity follows an Ornstein-Uhlenbeck process with mean-reversion rate `psi`
and volatility `sigma`, started at log z = 0. It is simulated with the process's exact
discrete-time transition. The random generator is seeded with a fixed seed (97), so
every call with the same arguments returns the same paths.

# Arguments
- `N`: number of countries.
- `t`: not used by this function; kept for interface compatibility.
- `sigma`: volatility of log productivity.
- `psi`: mean-reversion rate of log productivity.
- `years`: number of years returned.
- `dt`: time step in years (e.g. 1/12 for monthly).
- `BURN_IN`: fraction of the simulated periods that is dropped, so the returned paths
  do not start at the initial value. In total `years / (dt * (1 - BURN_IN))` periods
  are simulated, and the last `T` of them are kept.

# Returns
`shocks(zj, T)`, where `zj` is the `N×T` matrix of productivity levels.

# Throws
`ArgumentError` if `years / dt` or `years / (dt * (1 - BURN_IN))` is not (up to
floating-point noise) a whole number.

# Examples
```julia
julia> states = random_productivity_state(N, t, sigma, psi, years, dt, BURN_IN)
project_routines.shocks([0.871099412376654 0.9493402705040903 … 1.119913446658425;
                        0.618270399469909 0.6296256560735243 … 1.6540822971629776;
                        … ;
                        1.1571205561801143 1.08356700181851 … 1.1342460486978239;
                        0.5712339750758075 0.46131750459943976 … 0.9368051783030341],
                        600)
```
"""
function random_productivity_state(N, t, sigma, psi, years, dt, BURN_IN)
    # Convert a period count that should be integral to an `Int`. This tolerates
    # floating-point noise (e.g. `years / dt` with `dt = 1/12`) but rejects
    # genuinely fractional counts.
    function whole_periods(x, what)
        n = round(Int, x)
        isapprox(n, x; rtol = 1e-9) ||
            throw(ArgumentError("$what must be a whole number of periods, got $x"))
        return n
    end

    rng = MersenneTwister(97)  # fixed seed: every call returns the same paths
    logz0 = zeros(N)
    # Number of periods simulated (including burn-in) and number returned.
    draw_size = whole_periods(
        years / (dt * (1 - BURN_IN)), "years / (dt * (1 - BURN_IN))")
    T = whole_periods(years / dt, "years / dt")

    # Time grid. NOTE: it runs from 0 to `draw_size - dt` in steps of `dt`, so it
    # has about `draw_size / dt` points, although only the first `draw_size` are
    # used below. It is left unchanged because the number of normal draws (and
    # therefore the simulated paths for this seed) depends on its length.
    draw_vec = (0:dt:draw_size-dt)
    B = randn(rng, N, length(draw_vec) - 1)

    # Exact OU transition on the grid:
    # log z_t = e^{-ψt} log z_0
    #         + σ e^{-ψt} Σ_k sqrt(e^{2ψt_k} - e^{2ψt_{k-1}}) B_k / sqrt(2ψ)
    decay = repeat(exp.(-psi * draw_vec'), N, 1)
    logzj =
        decay .* logz0 +
        sigma * decay .* cumsum(
            [zeros(N) sqrt.(
                diff(repeat(exp.(2 * psi .* draw_vec') .- 1, N, 1), dims = 2),
            ) .* B],
            dims = 2,
        ) ./ sqrt(2 * psi)

    # Drop the burn-in: keep the last `T` of the first `draw_size` periods and
    # exponentiate only that window.
    keep_index = draw_size-T+1:draw_size
    zj = exp.(logzj[:, keep_index])
    return shocks(zj, T)
end
# ---------------------------------------------------------



# ---------------------------------------------------------
# ---------------------------------------------------------
## Utility Functions
# ---------------------------------------------------------
"""
    triangle_inequality(Tau)

Return a copy of the trade-cost matrix `Tau`, adjusted so that for every country `i`
and all `j, k ≠ i`, ``τ(i,k) ≤ τ(i,j) τ(j,k)``. In words, shipping directly is never
more expensive than shipping through a third country.

The check works on log costs and makes a single pass over all triples `(i, j, k)`.
When a triple violates the inequality, the smallest of `log τ(i,k)`, `log τ(i,j)` and
`log τ(j,k)` (the first one on ties) is reset so that the inequality holds with
equality.

Adjustments are written back immediately during the pass, so a later adjustment can
in principle break a triple that was checked earlier. The result is not re-checked.
`Tau` itself is not modified.
"""
function triangle_inequality(Tau)
    N = size(Tau, 1)
    lTau = log.(Tau)
    for i in 1:N
        others = [x for x in 1:N if x != i]
        for j in others, k in others
            direct, leg1, leg2 = lTau[i, k], lTau[i, j], lTau[j, k]
            direct <= leg1 + leg2 && continue
            # `argmin` yields a scalar position (unlike `findall`, whose vector
            # result never compares equal to 1 or 2).
            minind = argmin([direct, leg1, leg2])
            if minind == 1
                lTau[i, k] = leg1 + leg2
            elseif minind == 2
                lTau[i, j] = direct - leg2
            else
                lTau[j, k] = direct - leg1
            end
        end
    end
    return exp.(lTau)
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    remove_diagonal(x)

Return an `n × (n-1)` `Float64` matrix holding row `r` of `x` without its diagonal
entry `x[r, r]`, where `n = size(x, 1)`.
"""
function remove_diagonal(x)
    n = size(x, 1)
    out = Array{Float64}(undef, n, n - 1)
    for col in axes(out, 2), row in axes(out, 1)
        # Columns left of the diagonal map straight across; the rest shift by one
        # to skip x[row, row].
        src = col < row ? col : col + 1
        out[row, col] = x[row, src]
    end
    return out
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    nan_diag(X)

Return a copy of the square matrix `X` with its diagonal set to `NaN`. The element
type of `X` must be able to hold `NaN`.
"""
function nan_diag(X)
    n = size(X, 1)
    on_diag = [r == c for r in 1:n, c in 1:n]
    out = copy(X)
    out[on_diag] .= NaN
    return out
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    no_diag_mean(X, dim)

Return the mean of the off-diagonal elements of the square matrix `X` along
dimension `dim`.

The diagonal is replaced by zero in a copy of `X`, the copy is summed over `dim`, and
the result is divided by `N - 1` (`N = size(X, 1)`). `dim = 1` gives a `1×N` row of
column means and `dim = 2` an `N×1` column of row means. `X` is not modified.

```jldoctest
julia> no_diag_mean([100 2;3 100],2)
2×1 Matrix{Float64}:
 2.0
 3.0
```
"""
function no_diag_mean(X, dim)
    n = size(X, 1)
    is_diag = [r == c for r in 1:n, c in 1:n]
    masked = copy(X)
    masked[is_diag] .= 0.0
    return sum(masked; dims = dim) / (n - 1)
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    central_finder(Xij, Yj)

Compute a centrality score for each country and flag the countries above the median.

`Xij` is indexed as `Xij[i, j, tp]` and `Yj` as `Yj[j, tp]`. For each period `tp`:
- `share[i]` is the off-diagonal sum of column `i` of `Xij[:, :, tp]`, divided by the
  total off-diagonal sum of that slice.
- `cent[j, tp]` is the sum over `i ≠ j` of
  `share[i] * (X[i, j] + X[j, i]) / (Y[i] + Y[j])`.

# Returns
- `cent_out`: `N×1` matrix, the average of `cent` over periods.
- `central_ix`: `N×1` Boolean matrix, `true` where `cent_out` is strictly above its
  median.
"""
function central_finder(Xij, Yj)
    n = size(Yj, 1)
    cent = zeros(size(Yj))
    exp_share_world = zeros(size(Yj))

    for tp in 1:size(Yj, 2)
        Xt = Xij[:, :, tp]
        # Off-diagonal column sums of this period's slice, normalised to sum to one.
        col_sums = (n - 1) * no_diag_mean(Xt, 1)
        exp_share_world[:, tp] = col_sums ./ sum(col_sums[:])

        # Two-way flows relative to the combined size of each pair, with row i
        # weighted by share[i].
        two_way = Xt + transpose(Xt)
        yt = Yj[:, tp:tp]
        pair_size = yt .+ transpose(yt)
        weighted = (two_way ./ pair_size) .* exp_share_world[:, tp:tp]
        cent[:, tp] = (n - 1) * no_diag_mean(weighted, 1)
    end

    cent_out = mean(cent, dims = 2)
    return cent_out, cent_out .> median(cent_out)
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    summ_stat_univar(Xin)

Return summary statistics of the non-`NaN` entries of `Xin`, as an ordered mapping
from label to value. The keys, in order, are "Min", "Prc(1)", "Prc(5)", "Med", "Avg",
"Prc(95)", "Prc(99)", "Max", "IQR" and "STD".

NOTE(review): `OrderedDict`, `percentile` and `iqr` do not appear among this module's
visible `using` lines (they are typically provided by DataStructures.jl and
StatsBase.jl). Confirm they are in scope, otherwise calling this function raises an
`UndefVarError`.
"""
function summ_stat_univar(Xin)
    vals = vec(filter(!isnan, Xin))
    prc(p) = percentile(vals, p)
    stats = OrderedDict(
        "Min" => minimum(vals),
        "Prc(1)" => prc(1),
        "Prc(5)" => prc(5),
        "Med" => prc(50),
        "Avg" => mean(vals),
        "Prc(95)" => prc(95),
        "Prc(99)" => prc(99),
        "Max" => maximum(vals),
        "IQR" => iqr(vals),
        "STD" => std(vals),
    )
    return stats
end
# ---------------------------------------------------------


# ---------------------------------------------------------
"""
    tab_cfstat(X_before, X_after)

Return a `1×3` row `[mean(b) std(b) mean(a)]`, where `b` and `a` are the off-diagonal
entries (in column-major order) of the square matrices `X_before` and `X_after`.
"""
function tab_cfstat(X_before, X_after)
    function offdiag(X)
        n = size(X, 1)
        return X[[r != c for r in 1:n, c in 1:n]]
    end
    before = offdiag(X_before)
    after = offdiag(X_after)
    return [mean(before) std(before) mean(after)]
end
# ---------------------------------------------------------
# ---------------------------------------------------------


# ---------------------------------------------------------
end # module
# ---------------------------------------------------------